package org.test4j.mock.faking.meta;

import lombok.Getter;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

/**
 * Invocation-count verification for faked methods.
 *
 * <p>Expectations are registered with {@link #never}, {@link #min}, {@link #max},
 * {@link #between} and {@link #just}, and checked by {@link #verify()}. Invocations are counted
 * by {@link #increaseTimes(MethodId)}. Both expectations and counts are only recorded after
 * {@link #initVerify()} has been called; before that they are silently ignored.
 *
 * <p>State is held in {@link InheritableThreadLocal}s, so threads started after
 * {@link #initVerify()} share the same counter map and expectation list as the thread that
 * initialised them. Concurrent collections are therefore used for both.
 *
 * @author darui.wu
 */
public class TimesVerify {
    /** Invocation counts keyed by method key (text up to and including the first ')'). */
    private final static ThreadLocal<Map<String, Integer>> invokedTimes = new InheritableThreadLocal<>();

    /** Expectations registered since the last {@link #initVerify()}. */
    private final static ThreadLocal<List<TimesVerify>> timesVerifier = new InheritableThreadLocal<>();

    @Getter
    private final String classMethod;
    /**
     * Minimum expected number of calls (inclusive).
     */
    private final int min;
    /**
     * Maximum expected number of calls (inclusive).
     */
    private final int max;

    /**
     * Creates an expectation for one method.
     *
     * @param classMethod method identifier; everything after the first ')' is dropped
     * @param min         minimum expected number of calls (inclusive)
     * @param max         maximum expected number of calls (inclusive)
     */
    public TimesVerify(String classMethod, int min, int max) {
        this.classMethod = methodKey(classMethod);
        this.min = min;
        this.max = max;
    }

    /**
     * Checks the recorded invocation count of this method against {@code [min, max]}.
     *
     * @throws AssertionError if the method was called fewer than {@code min} or more than
     *                        {@code max} times
     */
    public void verifyTimes() {
        Map<String, Integer> counts = invokedTimes.get();
        int times = counts == null ? 0 : counts.getOrDefault(this.classMethod, 0);

        if (times < min) {
            throw new AssertionError(String.format("Method[%s] are expected to be called at least %d times, but be called %d times.",
                this.classMethod, min, times));
        } else if (times > max) {
            throw new AssertionError(String.format("Method[%s] are expected to be called at most %d times, but be called %d times.",
                this.classMethod, max, times));
        }
    }

    /**
     * Normalises a method identifier to the key used for counting.
     *
     * <p>Everything after the first ')' is dropped. For a JVM-style descriptor such as
     * {@code m(I)V}, this strips the return type. If there is no ')', the identifier is used
     * unchanged instead of collapsing to an empty string.
     */
    private static String methodKey(String classMethod) {
        int index = classMethod.indexOf(')');
        return index < 0 ? classMethod : classMethod.substring(0, index + 1);
    }

    /**
     * Registers an expectation; ignored when {@link #initVerify()} has not been called.
     */
    private static void add(TimesVerify timesVerify) {
        List<TimesVerify> verifiers = timesVerifier.get();
        if (verifiers == null) {
            return;
        }
        verifiers.add(timesVerify);
    }

    /**
     * Expects the method never to be called.
     *
     * @param classMethod method identifier
     */
    public static void never(String classMethod) {
        add(new TimesVerify(classMethod, 0, 0));
    }

    /**
     * Expects the method to be called at least {@code min} times.
     *
     * @param classMethod method identifier
     * @param min         minimum number of calls (inclusive)
     */
    public static void min(String classMethod, int min) {
        add(new TimesVerify(classMethod, min, Integer.MAX_VALUE));
    }

    /**
     * Expects the method to be called at most {@code max} times.
     *
     * @param classMethod method identifier
     * @param max         maximum number of calls (inclusive)
     */
    public static void max(String classMethod, int max) {
        add(new TimesVerify(classMethod, 0, max));
    }

    /**
     * Expects the number of calls to lie in {@code [min, max]}.
     *
     * @param classMethod method identifier
     * @param min         minimum number of calls (inclusive)
     * @param max         maximum number of calls (inclusive)
     */
    public static void between(String classMethod, int min, int max) {
        add(new TimesVerify(classMethod, min, max));
    }

    /**
     * Expects the method to be called exactly {@code times} times.
     *
     * @param classMethod method identifier
     * @param times       exact number of calls
     */
    public static void just(String classMethod, int times) {
        add(new TimesVerify(classMethod, times, times));
    }

    /**
     * Increments the invocation count of the given method.
     *
     * <p>Does nothing when {@link #initVerify()} has not been called.
     *
     * @param methodId the invoked method
     */
    public static void increaseTimes(MethodId methodId) {
        Map<String, Integer> counts = invokedTimes.get();
        if (counts == null) {
            return;
        }
        String classMethod = methodId.realClassDesc + "#" + methodId.name + methodId.descNoInvocation;
        // Invocations on classes under the JUnit runner / test4j integration packages are not counted.
        if (classMethod.startsWith("org/junit/runner") || classMethod.startsWith("org/test4j/integration")) {
            return;
        }
        // merge is atomic on ConcurrentHashMap, so increments from inheriting child threads are not lost.
        counts.merge(methodKey(classMethod), 1, Integer::sum);
    }

    /**
     * Starts a fresh verification round.
     *
     * <p>Creates the expectation list and counter map for this thread if absent, then clears both.
     */
    public static void initVerify() {
        if (timesVerifier.get() == null) {
            timesVerifier.set(new CopyOnWriteArrayList<>());
        }
        if (invokedTimes.get() == null) {
            invokedTimes.set(new ConcurrentHashMap<>());
        }
        timesVerifier.get().clear();
        invokedTimes.get().clear();
    }

    /**
     * Checks every registered expectation.
     *
     * @throws AssertionError on the first expectation whose count is out of range
     */
    public static void verify() {
        List<TimesVerify> verifiers = timesVerifier.get();
        if (verifiers == null) {
            return;
        }
        for (TimesVerify verify : verifiers) {
            verify.verifyTimes();
        }
    }
}